## Loading required package: lpSolve
##
## Attaching package: 'salso'
## The following object is masked from 'package:mcclust':
##
## binder
##
## Attaching package: 'ggpubr'
## The following object is masked from 'package:WASABI':
##
## ggscatter
##
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
The data is from Chen et al (2019) describing single-neuron axon projection counts to 11 brain areas: the orbitofrontal cortex (OFC), motor cortex (Motor), rostral striatum (Rstr), somatosensory cortex (SSctx), caudal striatum (Cstr), amygdala (Amyg), ipsilateral visual cortex (VisIp), contralateral visual cortex (VisC), contralateral auditory cortex (AudC), thalamus (Thal), and tectum (Tect). Data are collected across three brains (mice). For illustration, we focus on the first and third brain which are extracted using the same technology (BARseq).
# Load the BARseq projection data and keep brains 1 and 3 (both acquired
# with the BARseq technology, so directly comparable).
data("data_barseq")
data_barseq = list(data_barseq[[1]], data_barseq[[3]])
# M: number of brains (datasets); R: number of target regions (rows).
M <- length(data_barseq)
R <- nrow(data_barseq[[1]])
# Region labels; assumed identical across brains -- taken from brain 1.
regions.name <- rownames(data_barseq[[1]])
# C[m]: number of neurons (columns) in brain m.
C <- sapply(1:M, function(m) ncol(data_barseq[[m]]))
# mouse.index maps each neuron (stacked column index) to its brain of origin.
mouse.index <- c(rep(1, C[1]),
rep(2, C[2]))
Let’s visualize a heatmap of the data. Data are normalized by the total counts to reflect projection strength.
HBMAP employs a hierarchical mixture of Dirichlet-Multinomials to model axon projection data.
# Truncation level: maximum number of mixture components allowed.
J <- 35

# ---- Parameters passed to the main MCMC function ----

# MCMC settings: chain length, thinning, burn-in, adaptive-proposal rate,
# and checkpointing (disabled here).
mcmc_list <- list(
  number_iter = 20000,
  thinning = 5,
  burn_in = 10000,
  adaptive_prop = 0.0001,
  auto_save = FALSE,
  save_path = NULL,
  save_frequency = 1000
)

# Prior hyperparameters; any omitted entries fall back to package defaults.
prior_list <- list(
  a_gamma = 20, b_gamma = 1, lb_gamma = 1,
  a = 1, tau = 0.2, nu = 1 / 1000,
  a_alpha = 10, b_alpha = 1,
  a_alpha0 = 5, b_alpha0 = 1
)
# Initialization of the cluster allocations via k-means (k = 30) on
# cosine-transformed data, with 50 random restarts.
# NOTE(review): k_means_axon is an HBMAP helper; assumed to return one
# allocation vector per brain in the layout HBMAP_mcmc expects -- confirm.
set.seed(43)
Z.init <- k_means_axon(Y = data_barseq, k = 30, transformation = 'cosine', restart = 50, iter.max = 100)
# ------- Run the full model: 5 parallel MCMC chains ---------
# One RNG seed per chain.
seeds <- c(101, 112, 323, 141, 555)
# BUG FIX: the original called set.seed(g), so `seeds` was defined but never
# used; seed each chain with its intended entry of `seeds`.
mcmc_all_barseq <- parallel::mclapply(seq_along(seeds),
  function(g){
    set.seed(seeds[g])
    HBMAP_mcmc(Y = data_barseq, J = J, mcmc = mcmc_list,
               prior = prior_list, Z.init = Z.init, verbose = TRUE)
  },
  mc.cores = 5)
# Stack the clustering draws of every chain into one S x n matrix
# (rows = posterior samples, columns = neurons across both brains).
cls.draw.list <- lapply(seq_along(seeds), function(g){
  d <- mcmc_all_barseq[[g]]$Z_output
  matrix(unlist(d), length(d), sum(C), byrow = TRUE)
})
cls.draw <- do.call(rbind, cls.draw.list)
Let’s start by considering the optimal clustering solution obtained with different loss functions as well as the marginal posterior on the number of clusters.
# Relabel a clustering vector so that cluster labels become consecutive
# integers 1, 2, ..., K, numbered in order of first appearance.
#
# Args:
#   c: vector of arbitrary cluster labels (numeric or character).
# Returns:
#   Integer vector of the same length with labels in 1:K.
relabel <- function(c) {
  # match() against unique() (which preserves first-appearance order)
  # replaces the original O(n * K) loop and sidesteps the 1:length(uu)
  # empty-vector pitfall.
  match(c, unique(c))
}
# Canonicalize labels within every draw so cluster ids are 1..K per row.
cls.draw = t(apply(cls.draw,1,relabel))
# S: total number of posterior draws (pooled across chains).
S = dim(cls.draw)[1]
# Number of clusters in each draw: labels are consecutive after
# relabelling, so the maximum label equals the cluster count.
K.draw = apply(cls.draw,1,max)
# Posterior similarity matrix: pairwise co-clustering frequencies.
psm = mcclust::comp.psm(cls.draw)
# Point estimates of the clustering under different loss functions.
# Variation of Information (VI) -- the salso default loss.
output_salso = salso(x = cls.draw, maxZealousAttempts=50)
# Binder's loss.
output_salso_binder = salso(x = cls.draw, loss = "binder", maxNClusters = 100, maxZealousAttempts=50)
# One-minus adjusted Rand index (omARI).
output_salso_ari = salso(x = cls.draw, loss = "omARI",maxNClusters = 100, maxZealousAttempts=50)
# Bar chart of the marginal posterior on the number of clusters.
ggplot()+
geom_bar(aes(x=K.draw))+
theme_bw() +
labs(x="number of clusters")
print(paste('Mean of marginal posterior on the number of clusters:', mean(K.draw)))
## [1] "Mean of marginal posterior on the number of clusters: 26.8953523238381"
# VI estimate: cluster count, then two visualizations.
print(paste('Number of clusters in vi estimate:', length(unique(output_salso))))
## [1] "Number of clusters in vi estimate: 27"
# Heatmap of the PSM with rows/columns grouped by the VI clustering.
superheat(psm,
pretty.order.rows = TRUE,
pretty.order.cols = TRUE,
heat.pal = c("white", "yellow", "red"),
heat.pal.values = c(0,.5,1),
membership.rows = output_salso,
membership.cols = output_salso,
bottom.label.text.size = 4,
left.label.text.size = 4)
# Heatmap of row-normalized data; VI labels split by brain for heatmap_ps.
vi_list = lapply(1:M,function(m){output_salso[mouse.index==m]})
heatmap_ps(Y = data_barseq, Z = vi_list, regions.name = rownames(data_barseq[[1]]),
group.index = mouse.index, group.name = 'brain',
cluster.index = 1:length(unique(output_salso)), title = '')
# Binder's loss estimate: cluster count, then the same two visualizations.
print(paste('Number of clusters in binders estimate:', length(unique(output_salso_binder))))
## [1] "Number of clusters in binders estimate: 50"
# Heatmap of the PSM grouped by the Binder clustering.
superheat(psm,
pretty.order.rows = TRUE,
pretty.order.cols = TRUE,
heat.pal = c("white", "yellow", "red"),
heat.pal.values = c(0,.5,1),
membership.rows = output_salso_binder,
membership.cols = output_salso_binder,
bottom.label.text.size = 4,
left.label.text.size = 4)
# Heatmap of row-normalized data; Binder labels split by brain.
binder_list = lapply(1:M,function(m){output_salso_binder[mouse.index==m]})
heatmap_ps(Y = data_barseq, Z = binder_list, regions.name = rownames(data_barseq[[1]]),
group.index = mouse.index, group.name = 'brain',
cluster.index = 1:length(unique(output_salso_binder)), title = '')
# omARI estimate: cluster count, then the same two visualizations.
print(paste('Number of clusters in ari estimate:', length(unique(output_salso_ari))))
## [1] "Number of clusters in ari estimate: 37"
# Heatmap of the PSM grouped by the omARI clustering.
superheat(psm,
pretty.order.rows = TRUE,
pretty.order.cols = TRUE,
heat.pal = c("white", "yellow", "red"),
heat.pal.values = c(0,.5,1),
membership.rows = output_salso_ari,
membership.cols = output_salso_ari,
bottom.label.text.size = 4,
left.label.text.size = 4)
# Heatmap of row-normalized data; omARI labels split by brain.
ari_list = lapply(1:M,function(m){output_salso_ari[mouse.index==m]})
heatmap_ps(Y = data_barseq, Z = ari_list, regions.name = rownames(data_barseq[[1]]),
group.index = mouse.index, group.name = 'brain',
cluster.index = 1:length(unique(output_salso_ari)), title = '')
Different results are obtained in this case. Binder and ARI lead to a large number of clusters, with many small clusters. VI is more parsimonious. Let’s summarize with WASABI to better understand if there are multiple modes of clustering.
We use the elbow function to choose the number of particles \(L\) via the elbow method:
# WASABI elbow search over 1..L_max particles, timed with tictoc.
set.seed(123)
L_max = 10
tic()
out_elbow <- elbow(cls.draw, L_max = L_max, multi.start = 4,
method.init = "++",
mini.batch = 500, max.iter = 20, extra.iter = 4,
method = "salso")
## Completed 1 / 10
## Completed 2 / 10
## Completed 3 / 10
## Completed 4 / 10
## Completed 5 / 10
## Completed 6 / 10
## Completed 7 / 10
## Completed 8 / 10
## Completed 9 / 10
## Completed 10 / 10
toc()
## 3042.824 sec elapsed
L= 3
# Wasserstein distance vs number of particles; red circle marks the chosen L.
ggplot() +
geom_point(aes(x=c(1:L_max), y=out_elbow$wass_vec)) +
geom_line(aes(x=c(1:L_max), y=out_elbow$wass_vec)) +
labs(x="Number of particles", y="Wasserstein distance") +
annotate("point", x = L, y = out_elbow$wass_vec[L], color = "red", shape = 1, size = 3) +
theme_bw()
Once the value of \(L\) is chosen, we can run another set of initializations to see if we can find a better approximation:
# Re-run WASABI with 50 "++" initializations at the chosen L (5 cores).
tic()
output_WASABI_mb = WASABI_multistart(cls.draw, psm,
multi.start = 50, ncores = 5,
method.init ="++", add_topvi = FALSE,
method="salso", L=L,
mini.batch = 500,
max.iter= 20, extra.iter = 10,
swap_countone = TRUE,
maxNClusters = 45, maxZealousAttempts=20,
seed = 54321)
toc()
## 9156.843 sec elapsed
# Start from the elbow solution and keep the multistart run only if it
# attains a strictly smaller Wasserstein distance.
output_WASABI <- out_elbow$output_list[[L]]
if(output_WASABI_mb$wass.dist < output_WASABI$wass.dist){
output_WASABI <- output_WASABI_mb
print(paste('Improved approximation with multiple initialization: Wass dist =',output_WASABI$wass.dist))
}
## [1] "Improved approximation with multiple initialization: Wass dist = 0.937498746173342"
# WASABI with average-linkage initialization (verbose iteration log follows).
tic()
output_WASABI_avg = WASABI(cls.draw, psm, method.init ="average",
method="salso", L=L, mini.batch = 500,
maxNClusters = 45, maxZealousAttempts=20,
max.iter= 20, extra.iter = 10, swap_countone = TRUE,
suppress.comment = FALSE)
## Initial particle 1 : number of clusters = 24 , EVI = 0.987
## Initial particle 2 : number of clusters = 23 , EVI = 0.994
## Initial particle 3 : number of clusters = 22 , EVI = 0.997
## Iteration = 1
## Particle 1 : number of clusters=27 , EVI = 0.972 , sumVI = 0.745 , w= 0.766
## Particle 2 : number of clusters=29 , EVI = 0.992 , sumVI = 0.212 , w= 0.214
## Particle 3 : number of clusters=26 , EVI = 0.899 , sumVI = 0.018 , w= 0.02
## Wasserstein dist = 0.975023835414756
## Iteration = 2
## Particle 1 : number of clusters=27 , EVI = 0.959 , sumVI = 0.593 , w= 0.618
## Particle 2 : number of clusters=28 , EVI = 0.994 , sumVI = 0.316 , w= 0.318
## Particle 3 : number of clusters=25 , EVI = 0.907 , sumVI = 0.058 , w= 0.064
## Wasserstein dist = 0.96659155909822
## Iteration = 3
## Particle 1 : number of clusters=27 , EVI = 0.944 , sumVI = 0.391 , w= 0.414
## Particle 2 : number of clusters=28 , EVI = 0.995 , sumVI = 0.422 , w= 0.424
## Particle 3 : number of clusters=25 , EVI = 0.898 , sumVI = 0.146 , w= 0.162
## Wasserstein dist = 0.958539135616072
## Iteration = 4
## Particle 1 : number of clusters=27 , EVI = 0.953 , sumVI = 0.434 , w= 0.456
## Particle 2 : number of clusters=28 , EVI = 1.002 , sumVI = 0.379 , w= 0.378
## Particle 3 : number of clusters=25 , EVI = 0.919 , sumVI = 0.153 , w= 0.166
## Wasserstein dist = 0.965515185543286
## Iteration = 5
## Particle 1 : number of clusters=27 , EVI = 0.942 , sumVI = 0.463 , w= 0.492
## Particle 2 : number of clusters=27 , EVI = 1.001 , sumVI = 0.354 , w= 0.354
## Particle 3 : number of clusters=25 , EVI = 0.92 , sumVI = 0.142 , w= 0.154
## Wasserstein dist = 0.959562766349092
## Iteration = 6
## Particle 1 : number of clusters=29 , EVI = 0.947 , sumVI = 0.494 , w= 0.522
## Particle 2 : number of clusters=26 , EVI = 0.999 , sumVI = 0.328 , w= 0.328
## Particle 3 : number of clusters=26 , EVI = 0.909 , sumVI = 0.136 , w= 0.15
## Wasserstein dist = 0.958268319961711
## Iteration = 7
## Particle 1 : number of clusters=29 , EVI = 0.963 , sumVI = 0.481 , w= 0.5
## Particle 2 : number of clusters=25 , EVI = 1.007 , sumVI = 0.304 , w= 0.302
## Particle 3 : number of clusters=26 , EVI = 0.92 , sumVI = 0.182 , w= 0.198
## Wasserstein dist = 0.967676608111123
## Iteration = 8
## Particle 1 : number of clusters=28 , EVI = 0.962 , sumVI = 0.523 , w= 0.544
## Particle 2 : number of clusters=25 , EVI = 0.997 , sumVI = 0.255 , w= 0.256
## Particle 3 : number of clusters=26 , EVI = 0.932 , sumVI = 0.186 , w= 0.2
## Wasserstein dist = 0.964867984158873
## Iteration = 9
## Particle 1 : number of clusters=27 , EVI = 0.953 , sumVI = 0.481 , w= 0.504
## Particle 2 : number of clusters=27 , EVI = 0.989 , sumVI = 0.289 , w= 0.292
## Particle 3 : number of clusters=25 , EVI = 0.923 , sumVI = 0.188 , w= 0.204
## Wasserstein dist = 0.957460908220007
## Iteration = 10
## Particle 1 : number of clusters=27 , EVI = 0.964 , sumVI = 0.482 , w= 0.5
## Particle 2 : number of clusters=26 , EVI = 0.987 , sumVI = 0.286 , w= 0.29
## Particle 3 : number of clusters=26 , EVI = 0.932 , sumVI = 0.196 , w= 0.21
## Wasserstein dist = 0.963786856334065
## Iteration = 11
## Particle 1 : number of clusters=27 , EVI = 0.95 , sumVI = 0.479 , w= 0.504
## Particle 2 : number of clusters=26 , EVI = 0.995 , sumVI = 0.294 , w= 0.296
## Particle 3 : number of clusters=26 , EVI = 0.904 , sumVI = 0.181 , w= 0.2
## Wasserstein dist = 0.954280828876736
## Iteration = 12
## Particle 1 : number of clusters=28 , EVI = 0.953 , sumVI = 0.545 , w= 0.572
## Particle 2 : number of clusters=26 , EVI = 0.987 , sumVI = 0.259 , w= 0.262
## Particle 3 : number of clusters=27 , EVI = 0.912 , sumVI = 0.151 , w= 0.166
## Wasserstein dist = 0.955295323499047
## *Running full batch after mini-batch*
## Iteration = 13
## Particle 1 : number of clusters=28 , EVI = 0.953 , sumVI = 0.498 , w= 0.523
## Particle 2 : number of clusters=26 , EVI = 0.991 , sumVI = 0.275 , w= 0.277
## Particle 3 : number of clusters=26 , EVI = 0.918 , sumVI = 0.183 , w= 0.199
## Wasserstein dist = 0.956424856071233
## Iteration = 14
## Particle 1 : number of clusters=28 , EVI = 0.951 , sumVI = 0.493 , w= 0.518
## Particle 2 : number of clusters=26 , EVI = 0.992 , sumVI = 0.278 , w= 0.28
## Particle 3 : number of clusters=26 , EVI = 0.918 , sumVI = 0.186 , w= 0.202
## Wasserstein dist = 0.95604751395645
# Report the Wasserstein distance achieved from average-linkage init.
print(paste('Average initialization: Wass dist =',output_WASABI_avg$wass.dist))
## [1] "Average initialization: Wass dist = 0.95604751395645"
toc()
## 165.694 sec elapsed
# WASABI with complete-linkage initialization (verbose iteration log follows).
tic()
output_WASABI_comp = WASABI(cls.draw, psm, method.init ="complete",
method="salso", L=L, mini.batch = 500,
maxNClusters = 45, maxZealousAttempts=20,
max.iter= 20, extra.iter = 10, swap_countone = TRUE,
suppress.comment = FALSE)
## Initial particle 1 : number of clusters = 23 , EVI = 1.032
## Initial particle 2 : number of clusters = 22 , EVI = 1.04
## Initial particle 3 : number of clusters = 26 , EVI = 1.045
## Iteration = 1
## Particle 1 : number of clusters=28 , EVI = 0.947 , sumVI = 0.428 , w= 0.452
## Particle 2 : number of clusters=28 , EVI = 0.956 , sumVI = 0.163 , w= 0.17
## Particle 3 : number of clusters=29 , EVI = 0.984 , sumVI = 0.372 , w= 0.378
## Wasserstein dist = 0.962473273883556
## Iteration = 2
## Particle 1 : number of clusters=27 , EVI = 0.949 , sumVI = 0.44 , w= 0.464
## Particle 2 : number of clusters=27 , EVI = 0.954 , sumVI = 0.193 , w= 0.202
## Particle 3 : number of clusters=30 , EVI = 0.969 , sumVI = 0.324 , w= 0.334
## Wasserstein dist = 0.956685416838028
## Iteration = 3
## Particle 1 : number of clusters=27 , EVI = 0.942 , sumVI = 0.563 , w= 0.598
## Particle 2 : number of clusters=26 , EVI = 0.951 , sumVI = 0.192 , w= 0.202
## Particle 3 : number of clusters=30 , EVI = 0.932 , sumVI = 0.186 , w= 0.2
## Wasserstein dist = 0.941559358077484
## Iteration = 4
## Particle 1 : number of clusters=26 , EVI = 0.939 , sumVI = 0.533 , w= 0.568
## Particle 2 : number of clusters=24 , EVI = 0.967 , sumVI = 0.234 , w= 0.242
## Particle 3 : number of clusters=31 , EVI = 0.922 , sumVI = 0.175 , w= 0.19
## Wasserstein dist = 0.942299756511466
## *Running full batch after mini-batch*
## Iteration = 5
## Particle 1 : number of clusters=27 , EVI = 0.942 , sumVI = 0.527 , w= 0.559
## Particle 2 : number of clusters=25 , EVI = 0.967 , sumVI = 0.241 , w= 0.249
## Particle 3 : number of clusters=31 , EVI = 0.921 , sumVI = 0.177 , w= 0.192
## Wasserstein dist = 0.944386961370636
## Iteration = 6
## Particle 1 : number of clusters=27 , EVI = 0.943 , sumVI = 0.547 , w= 0.581
## Particle 2 : number of clusters=25 , EVI = 0.97 , sumVI = 0.226 , w= 0.233
## Particle 3 : number of clusters=31 , EVI = 0.912 , sumVI = 0.17 , w= 0.186
## Wasserstein dist = 0.943234195737468
## Iteration = 7
## Particle 1 : number of clusters=27 , EVI = 0.943 , sumVI = 0.559 , w= 0.593
## Particle 2 : number of clusters=26 , EVI = 0.97 , sumVI = 0.217 , w= 0.223
## Particle 3 : number of clusters=31 , EVI = 0.909 , sumVI = 0.167 , w= 0.183
## Wasserstein dist = 0.942959596749585
# Report the Wasserstein distance achieved from complete-linkage init.
print(paste('Complete initialization: Wass dist =',output_WASABI_comp$wass.dist))
## [1] "Complete initialization: Wass dist = 0.942959596749585"
toc()
## 240.105 sec elapsed
# Fixed initialization: L Binder point estimates with increasing caps on
# the number of clusters, one particle per row.
part.init <- matrix(0, L, sum(C))
nclus <- c(25, 28, 32)
# seq_len(L) instead of c(1:L): safe if L were ever 0.
for (l in seq_len(L)) {
  part.init[l, ] <- as.numeric(salso(x = cls.draw, loss = "binder", maxNClusters = nclus[l]))
}
tic()
# WASABI started from the fixed particles in part.init (no mini-batching).
output_WASABI_fxd = WASABI(cls.draw, psm, method.init ="fixed", part.init = part.init,
method="salso", L=L,
maxNClusters = 45, maxZealousAttempts=20,
max.iter= 30, swap_countone = TRUE, suppress.comment = FALSE)
## Initial particle 1 : number of clusters = 25 , EVI = 1.082
## Initial particle 2 : number of clusters = 28 , EVI = 1.067
## Initial particle 3 : number of clusters = 32 , EVI = 1.062
## Iteration = 1
## Particle 1 : number of clusters=26 , EVI = 0.911 , sumVI = 0.003 , w= 0.004
## Particle 2 : number of clusters=27 , EVI = 0.967 , sumVI = 0.307 , w= 0.318
## Particle 3 : number of clusters=28 , EVI = 0.981 , sumVI = 0.666 , w= 0.678
## Wasserstein dist = 0.976621847585401
## Iteration = 2
## Particle 1 : number of clusters=25 , EVI = 0.982 , sumVI = 0.197 , w= 0.201
## Particle 2 : number of clusters=27 , EVI = 0.974 , sumVI = 0.382 , w= 0.392
## Particle 3 : number of clusters=28 , EVI = 0.958 , sumVI = 0.389 , w= 0.406
## Wasserstein dist = 0.969011163807456
## Iteration = 3
## Particle 1 : number of clusters=25 , EVI = 0.999 , sumVI = 0.282 , w= 0.282
## Particle 2 : number of clusters=27 , EVI = 0.962 , sumVI = 0.346 , w= 0.36
## Particle 3 : number of clusters=28 , EVI = 0.948 , sumVI = 0.339 , w= 0.358
## Wasserstein dist = 0.967320157940298
## Iteration = 4
## Particle 1 : number of clusters=26 , EVI = 1.005 , sumVI = 0.279 , w= 0.277
## Particle 2 : number of clusters=27 , EVI = 0.959 , sumVI = 0.349 , w= 0.364
## Particle 3 : number of clusters=28 , EVI = 0.944 , sumVI = 0.339 , w= 0.359
## Wasserstein dist = 0.966111948393442
## Iteration = 5
## Particle 1 : number of clusters=26 , EVI = 1.005 , sumVI = 0.286 , w= 0.285
## Particle 2 : number of clusters=27 , EVI = 0.959 , sumVI = 0.36 , w= 0.376
## Particle 3 : number of clusters=28 , EVI = 0.94 , sumVI = 0.319 , w= 0.339
## Wasserstein dist = 0.965703954987267
# Report the Wasserstein distance achieved from fixed initialization.
print(paste('Fixed initialization: Wass dist =',output_WASABI_fxd$wass.dist))
## [1] "Fixed initialization: Wass dist = 0.965703954987267"
toc()
## 408.364 sec elapsed
WASABI provides a number of visualization tools. Let’s first consider the number and weights of the particles.
ggsummary(output_WASABI)
We can also plot the data colored by cluster membership for each particle.
# Create a matrix of normalized data and filter for two regions to draw a scatter plot
# NOTE(review): phat and phat2 (per-brain normalized matrices) are not
# defined in any visible chunk -- presumably computed earlier when the data
# heatmap was drawn; verify before running standalone.
data_norm = list(phat, phat2)
# Row indices of the two regions used as scatter axes (names looked up in
# regions.name below).
r1 = 5
r2 = 9
data_norm_r1r2_list = lapply(data_norm, function(d){d[c(r1,r2),]})
# Stack into an n x 2 matrix: one row per neuron, one column per region.
data_norm_r1r2 = matrix(unlist(data_norm_r1r2_list), sum(C),2 ,byrow = TRUE)
# Scatter of the two regions, faceted/coloured by particle via WASABI helper.
ggscatter_grid2d(output_WASABI, data_norm_r1r2) +
labs(x=regions.name[r1], y=regions.name[r2])
To better investigate the differences between any two particles, we can also look at the VI contribution of each point (e.g. particle 1 and particle 2):
# Particles to compare.
p1 = 1
p2 = 2
# Per-neuron VI contribution between the two particles.
VIC_p1p2 = vi.contribution(output_WASABI$particles[p1,],output_WASABI$particles[p2,])
# Meet (common refinement) of the two particles.
meet_p1p2 = cls.meet(output_WASABI$particles[c(p1,p2),])
# Shared colour palette; reused for the later VIC/EVI plots.
colors <- rev(sequential_hcl(5, palette = "Purple-Yellow")[1:4])
# Scatter of the two selected regions, coloured by VI contribution and
# shaped by meet cluster.
ggplot() +
geom_point(aes(x = data_norm_r1r2[,1],
y = data_norm_r1r2[,2],
color = VIC_p1p2,
shape = as.factor(meet_p1p2$cls.m))) +
theme_bw() +
#scale_color_distiller(name = "VIC",palette = "OrRd",direction = 1)+
scale_color_gradientn(colours = colors, transform = "sqrt", labels = function(x) sprintf("%.4f", x))+
scale_shape_manual(values=c(1:length(unique(meet_p1p2$cls.m))))+
guides(shape = guide_legend(title="Meet\ncluster")) +
labs(x=regions.name[r1], y=regions.name[r2]) +
ggtitle("VI Contribution between particle 1 and 2")
Alternative useful visualizations of MAPseq data are provided by gel plots (heat maps of the normalized data) and line plots (line plots of the normalized data). In the line plots, we filter to clusters/motifs containing at least 10 neurons, as we are not interested in projection patterns characterized by only a small group of neurons.
# Heat maps of row-normalized data for each particle, coloured by the
# per-neuron VI contribution relative to particle 1.
# (Particle 1 vs itself gives a zero baseline for comparison.)
VIC_p1p1 = vi.contribution(output_WASABI$particles[1,],output_WASABI$particles[1,])
VIC_p1p2 = vi.contribution(output_WASABI$particles[1,],output_WASABI$particles[2,])
VIC_p1p3 = vi.contribution(output_WASABI$particles[1,],output_WASABI$particles[3,])
# Common colour-scale limits across the three panels.
lmts = c(0, max(max(VIC_p1p1),max(VIC_p1p2),max(VIC_p1p3)))
# Per-brain label lists required by heatmap_VIC ("limts" is the helper's
# argument spelling).
p1_list = lapply(1:M,function(m){output_WASABI$particles[1,mouse.index==m]})
ps_p1 = heatmap_VIC(Y = data_barseq, Z = p1_list, regions.name = rownames(data_barseq[[1]]),
vic = VIC_p1p1,
cluster.index = 1:length(unique(output_WASABI$particles[1,])),
title = paste('Particle 1',round(output_WASABI$part.weights[1],3)),
limts = lmts)
p2_list = lapply(1:M,function(m){output_WASABI$particles[2,mouse.index==m]})
ps_p2 = heatmap_VIC(Y = data_barseq, Z = p2_list, regions.name = rownames(data_barseq[[1]]),
vic = VIC_p1p2,
cluster.index = 1:length(unique(output_WASABI$particles[2,])),
title = paste('Particle 2',round(output_WASABI$part.weights[2],3)),
limts = lmts)
p3_list = lapply(1:M,function(m){output_WASABI$particles[3,mouse.index==m]})
ps_p3 = heatmap_VIC(Y = data_barseq, Z = p3_list, regions.name = rownames(data_barseq[[1]]),
vic = VIC_p1p3,
cluster.index = 1:length(unique(output_WASABI$particles[3,])),
title = paste('Particle 3',round(output_WASABI$part.weights[3],3)),
limts = lmts)
# Arrange the three panels side by side with a shared legend.
ggarrange(ps_p1, ps_p2, ps_p3, ncol=3, nrow=1, common.legend = TRUE, legend="right")
# Heat maps of row-normalized data, coloured by the group-level VIC:
# the per-neuron VI contribution summed over each meet cluster.
meet_p1p2 <- cls.meet(output_WASABI$particles[c(1,2),])$cls.m
VIC_p1p2_group <- sapply(1:sum(C), function(c){sum(VIC_p1p2[meet_p1p2 == meet_p1p2[c]])})
meet_p1p3 <- cls.meet(output_WASABI$particles[c(1,3),])$cls.m
VIC_p1p3_group <- sapply(1:sum(C), function(c){sum(VIC_p1p3[meet_p1p3 == meet_p1p3[c]])})
# BUG FIX: the shared colour limits previously mixed the per-neuron
# VIC_p1p2 with the grouped VIC_p1p3_group; since the particle-2 panel is
# coloured by VIC_p1p2_group (which can exceed VIC_p1p2), use the grouped
# quantity for both so the fill scale is not clipped.
lmts <- c(0, max(max(VIC_p1p1), max(VIC_p1p2_group), max(VIC_p1p3_group)))
# Particle 1 vs itself: grouped VIC is identically zero, shown as baseline.
ps_p1 = heatmap_VIC(Y = data_barseq, Z = p1_list, regions.name = rownames(data_barseq[[1]]),
vic = VIC_p1p1,
cluster.index = 1:length(unique(output_WASABI$particles[1,])),
title = paste('Particle 1',round(output_WASABI$part.weights[1],3)),
limts = lmts) +
labs(fill="VICG")
# Particle 2 coloured by grouped VIC relative to particle 1.
p2_list = lapply(1:M,function(m){output_WASABI$particles[2,mouse.index==m]})
ps_p2 = heatmap_VIC(Y = data_barseq, Z = p2_list, regions.name = rownames(data_barseq[[1]]),
vic = VIC_p1p2_group,
cluster.index = 1:length(unique(output_WASABI$particles[2,])),
title = paste('Particle 2',round(output_WASABI$part.weights[2],3)),
limts = lmts) +
labs(fill="VICG")
# Particle 3 coloured by grouped VIC relative to particle 1.
p3_list = lapply(1:M,function(m){output_WASABI$particles[3,mouse.index==m]})
ps_p3 = heatmap_VIC(Y = data_barseq, Z = p3_list, regions.name = rownames(data_barseq[[1]]),
vic = VIC_p1p3_group,
cluster.index = 1:length(unique(output_WASABI$particles[3,])),
title = paste('Particle 3',round(output_WASABI$part.weights[3],3)),
limts = lmts) +
labs(fill="VICG")
# Arrange the three panels side by side with a shared legend.
ggarrange(ps_p1, ps_p2, ps_p3, ncol=3, nrow=1, common.legend = TRUE, legend="right")
# Split each per-neuron VIC vector by brain, mirroring the per-brain list
# layout used elsewhere (e.g. p1_list); the line plots are coloured by
# these values to highlight differences between particles.
VIC_p1p1_list <- lapply(1:M, function(m) VIC_p1p1[mouse.index == m])
VIC_p1p2_list <- lapply(1:M, function(m) VIC_p1p2[mouse.index == m])
VIC_p1p3_list <- lapply(1:M, function(m) VIC_p1p3[mouse.index == m])
# Keep only clusters/motifs with more than 10 neurons in each particle.
mouse.list = lapply(1:M, function(m){rep(as.factor(m),C[m])})
motifs_filter1 = which(table(output_WASABI$particles[1,])>10)
motifs_filter2 = which(table(output_WASABI$particles[2,])>10)
motifs_filter3 = which(table(output_WASABI$particles[3,])>10)
# Line plots of normalized projections per cluster, coloured by VIC.
pl_p1_vic = projection_vic(data_barseq, mouse.list, p1_list, regions.name,motifs=motifs_filter1, VIC_p1p1_list, limts = lmts, ncol=7) +
labs(title = 'Particle 1')
pl_p2_vic = projection_vic(data_barseq, mouse.list, p2_list, regions.name,motifs=motifs_filter2, VIC_p1p2_list, limts = lmts, ncol=7) +
labs(title = 'Particle 2')
pl_p3_vic = projection_vic(data_barseq, mouse.list, p3_list, regions.name,motifs=motifs_filter3, VIC_p1p3_list, limts = lmts, ncol=7) +
labs(title = 'Particle 3')
# Stack the three particles vertically with a shared legend.
ggarrange(pl_p1_vic, pl_p2_vic, pl_p3_vic, ncol=1, nrow=3, common.legend = TRUE, legend="right")
To label each cluster, we consider that neurons in each group project to a region if the average projection strength is at least 0.02.
# Label each cluster of particle p by the regions whose average projection
# strength is at least 0.02.
p <- 1
# R x n matrix of normalized data: regions in rows, neurons in columns.
data_norm_cbind <- t(matrix(unlist(data_norm), sum(C), R, byrow = TRUE))
part_motif_names <- lapply(sort(unique(output_WASABI$particles[p,])),
  function(j){
    # Neurons assigned to cluster j in particle p.
    in.j <- output_WASABI$particles[p,] == j
    # Force an R x n_j matrix even when the cluster has a single neuron.
    data.j <- matrix(data_norm_cbind[, in.j], nrow = R, ncol = sum(in.j))
    data.j.average <- rowMeans(data.j)
    pp.regions <- paste(regions.name[data.j.average >= 0.02], collapse = ',')
    pp.weight <- sum(in.j)/sum(C)
    data.frame(cluster = j,
               pp.regions = pp.regions,
               pp.weight = pp.weight,
               pp.strength = data.j.average)
  })
part_motif_names <- do.call(rbind, part_motif_names)
# Each cluster contributes R rows (one per region); print one row per
# cluster. Use R rather than the hard-coded 11 so this generalizes to any
# number of regions.
print(part_motif_names[seq(1, nrow(part_motif_names), R), c(1,2)])
## cluster pp.regions
## 1 1 Amyg
## 12 2 OFC,Motor,Rstr,SSctx,Cstr,Amyg,VisIp,AudC
## 23 3 AudC
## 34 4 Cstr,Amyg,AudC
## 45 5 Rstr,Cstr,Amyg,VisIp,VisC,AudC,Thal
## 56 6 Thal,Tect
## 67 7 Thal
## 78 8 Amyg,VisIp
## 89 9 VisIp
## 100 10 VisIp,VisC,Tect
## 111 11 SSctx,Amyg,AudC
## 122 12 SSctx,VisIp
## 133 13 SSctx,VisIp
## 144 14 VisIp,Thal,Tect
## 155 15 Cstr,VisC,AudC,Thal
## 166 16 OFC,Cstr,Amyg,VisIp
## 177 17 Cstr,Thal,Tect
## 188 18 Cstr,AudC
## 199 19 Thal,Tect
## 210 20 Rstr,Cstr
## 221 21 Cstr,Thal
## 232 22 Cstr,VisIp,AudC
## 243 23 OFC,VisC
## 254 24 Cstr,VisIp,AudC
## 265 25 SSctx,VisIp,AudC
## 276 26 OFC,Rstr,Cstr
## 287 27 VisIp,AudC,Tect
## 298 28 OFC,Motor,Rstr,SSctx,Cstr,Thal,Tect
We can plot the PSM within each region of attraction/neighborhood to understand the uncertainty of each particle.
# PSM restricted to the draws assigned to particle 1's region of attraction.
psm_p1 = mcclust::comp.psm(cls.draw[output_WASABI$draws.assign==1,])
hpsm_p1 = superheat(psm_p1,
pretty.order.rows = TRUE,
pretty.order.cols = TRUE,
heat.pal = c("white", "yellow", "red"),
heat.pal.values = c(0,.5,1),
membership.rows = output_WASABI$particles[1,],
membership.cols = output_WASABI$particles[1,],
bottom.label.text.size = 4,
left.label.text.size = 4,
title = "PSM within particle 1's neighborhood")
# PSM restricted to the draws assigned to particle 2's region of attraction.
psm_p2 = mcclust::comp.psm(cls.draw[output_WASABI$draws.assign==2,])
hpsm_p2 = superheat(psm_p2,
pretty.order.rows = TRUE,
pretty.order.cols = TRUE,
heat.pal = c("white", "yellow", "red"),
heat.pal.values = c(0,.5,1),
membership.rows = output_WASABI$particles[2,],
membership.cols = output_WASABI$particles[2,],
bottom.label.text.size = 4,
left.label.text.size = 4,
title = "PSM within particle 2's neighborhood")
# PSM restricted to the draws assigned to particle 3's region of attraction.
psm_p3 = mcclust::comp.psm(cls.draw[output_WASABI$draws.assign==3,])
hpsm_p3 = superheat(psm_p3,
pretty.order.rows = TRUE,
pretty.order.cols = TRUE,
heat.pal = c("white", "yellow", "red"),
heat.pal.values = c(0,.5,1),
membership.rows = output_WASABI$particles[3,],
membership.cols = output_WASABI$particles[3,],
bottom.label.text.size = 4,
left.label.text.size = 4,
title = "PSM within particle 3's neighborhood")
# ggarrange(hpsm_p1, hpsm_p2, hpsm_p3, ncol=1, nrow=3, common.legend = TRUE, legend="bottom")
We can also find the meet of the particles.
First, we show line plots of neurons in each meet cluster, colored by the contribution to EVI.
# Meet (common refinement) of all WASABI particles.
output_meet = cls.meet(output_WASABI$particles)
z_meet = output_meet$cls.m
# Keep meet clusters with more than 10 neurons.
motifs_filter.m = which(table(z_meet)>10)
# Per-neuron contribution to the expected VI / Wasserstein distance.
evi.m = evi.wd.contribution(output_WASABI, z_meet)
# Split meet labels and EVI contributions by brain for plotting.
meet_list = lapply(1:M,function(m){z_meet[mouse.index==m]})
evi.m_list = list()
for (m in 1:M){
evi.m_list[[m]] = evi.m[mouse.index==m]
}
# Line plots of neurons in each (large) meet cluster, coloured by EVI.
projection_vic(data_barseq, mouse.list, meet_list, regions.name, motifs=motifs_filter.m, evi.m_list, ncol=5)
The posterior similarity matrix approximated by WASABI and collapsed to the meet clusters helps us to understand which meet clusters are grouped across particles.
# PSM collapsed to the meet clusters, labelled 1..Km.
psm.m <- psm.meet(z_meet, output_WASABI)
Km <- nrow(psm.m)
colnames(psm.m) <- 1:Km; rownames(psm.m) <- 1:Km
# Compare the meet with particle i: cross-tabulate particle clusters
# against meet clusters and keep the non-empty combinations, ordered by
# particle cluster.
i <- 1
part_cl <- output_WASABI$particles[i,]
tb_meettop <- table(part_cl, z_meet)
# (Dead code removed: an initial lbs_top assignment based on column-wise
# which.max was immediately overwritten below.)
tmp <- reshape2::melt(as.matrix(as.data.frame.matrix(tb_meettop))) %>%
  arrange(Var1) %>% filter(value > 0)
# lbs_top: particle-cluster label for each retained (particle, meet) pair;
# tmp is then reduced to the matching meet-cluster ids (a plain vector).
lbs_top <- tmp %>% pull(Var1)
tmp <- tmp %>% pull(Var2)
# Meet-cluster PSM ordered by particle i's clusters; row membership shows
# the owning particle cluster, column membership the meet cluster id.
superheat::superheat(psm.m[tmp,tmp],
heat.pal = c("white", "yellow", "red"),
heat.pal.values = c(0,.5,1),
heat.lim = c(0,1), # fix the colour scale to [0,1] so panels are comparable
row.title = paste('Particle',i),
column.title = paste('Meet'),
membership.rows = as.numeric(lbs_top),
membership.cols = tmp,
bottom.label.text.angle = 90,
bottom.label.text.size = 3,
left.label.text.size = 3)
#ggsave(pm$plot,filename = "psm_meet_ac.png", device = "png",
# width = 7.5, height = 8,units = "in", scale = 1)
Let’s investigate the sizes of the meet clusters, colored by their total EVI contribution (sum across neurons in the meet cluster). This helps us to understand which meet clusters are stable (groups of neurons with distinct projection patterns), and which may be more uncertain.
# Rebuild tmp as a data frame (it was reduced to a vector above): non-empty
# (particle-cluster, meet-cluster) pairs ordered by particle cluster.
tmp = reshape2::melt(as.matrix(as.data.frame.matrix(tb_meettop))) %>%
arrange(Var1) %>% filter(value > 0)
# Total EVI contribution per meet cluster (floored at 0).
evi.m.group = sapply(unique(tmp$Var2), function(m){max(sum(evi.m[z_meet==m]),0)})
# NOTE(review): evi.m.unique is never used in the remainder of this
# document -- candidate for removal.
evi.m.unique= sapply(unique(tmp$Var2), function(m){unique(evi.m[z_meet==m])})
# NOTE(review): this assumes each meet cluster appears under exactly one
# particle cluster, so that tmp$Var2 aligns with unique(tmp$Var2) -- verify.
df = data.frame(cluster = factor(tmp$Var2, levels = tmp$Var2 ), size = tmp$value, EVI = evi.m.group)
# Bar chart of meet-cluster sizes, filled by total EVI contribution.
ggplot(df) +
geom_col(aes(x = cluster, y = size, fill = EVI)) +
theme_bw() +
scale_fill_gradientn(colours = colors, transform = "sqrt", labels = function(x) sprintf("%.4f", x))+
theme(axis.text.x = element_text(size = 10, angle = 90, vjust = 0.5),
axis.title.x = element_text(size = 12),
axis.title.y = element_text(size = 12,angle = 90)) +
labs(fill = "EVICG")
To visualize the stable meet clusters, we focus on those with at least 10 neurons and the EVI by group less than 0.002.
# Per-brain group-EVI values aligned to each neuron's meet cluster
# (reordering evi.m.group into ascending meet-cluster order first).
# NOTE(review): `index.return = T` uses the reassignable T -- prefer TRUE.
evi.m.group_list = list()
for (m in 1:M){
evi.m.group_list[[m]] = evi.m.group[sort(unique(tmp$Var2), index.return = T)$ix][z_meet[mouse.index==m]]
}
# Stable meet clusters: more than 10 neurons and group EVI below 0.002.
motifs_filter.m = df$cluster[(evi.m.group<0.002)&(df$size>10)]
projection_vic(data_barseq, mouse.list, meet_list, regions.name, motifs=motifs_filter.m, evi.m.group_list, ncol=5, limts = c(0, max(evi.m.group))) + labs(color = "EVICG")
We can also look more carefully at some of the other meet clusters, for example those that form the noisy cluster 2 in particle 1.
# Meet clusters composing cluster 2 of particle 1.
# Consistency: select the Var2 column by name, matching the
# tmp$Var2[tmp$Var1==...] idiom used for the other clusters below,
# instead of the positional index 2.
motifs_filter.m = tmp$Var2[tmp$Var1==2]
projection_vic(data_barseq, mouse.list, meet_list, regions.name, motifs=motifs_filter.m, evi.m.group_list, ncol=5, limts = c(0, max(evi.m.group))) + labs(color = "EVICG")
Another example is the meet clusters that form cluster 4 or 5 in particle 1.
# Meet clusters composing cluster 4 of particle 1.
motifs_filter.m = tmp$Var2[tmp$Var1==4]
projection_vic(data_barseq, mouse.list, meet_list, regions.name, motifs=motifs_filter.m, evi.m.group_list, ncol=3, limts = c(0, max(evi.m.group))) + labs(color = "EVICG")
# Meet clusters composing cluster 5 of particle 1.
motifs_filter.m = tmp$Var2[tmp$Var1==5]
projection_vic(data_barseq, mouse.list, meet_list, regions.name, motifs=motifs_filter.m, evi.m.group_list, ncol=3, limts = c(0, max(evi.m.group))) + labs(color = "EVICG")
Let’s also investigate the Thal and Tect clusters: - Cluster 19 of particle 1 has moderate projection to Thal and Tect only - Cluster 6 of particle 1 also projects to Thal and Tect but with higher strength to Thal - Cluster 17 of particle 1 also projects to Thal and Tect but with weak strength also to Cstr
# Meet clusters composing cluster 19 of particle 1 (Thal/Tect motif).
motifs_filter.m = tmp$Var2[tmp$Var1==19]
projection_vic(data_barseq, mouse.list, meet_list, regions.name, motifs=motifs_filter.m, evi.m.group_list, ncol=4, limts = c(0, max(evi.m.group))) + labs(color = "EVICG")
# Meet clusters composing cluster 6 of particle 1.
motifs_filter.m = tmp$Var2[tmp$Var1==6]
projection_vic(data_barseq, mouse.list, meet_list, regions.name, motifs=motifs_filter.m, evi.m.group_list, ncol=3, limts = c(0, max(evi.m.group))) + labs(color = "EVICG")
# Meet clusters composing cluster 17 of particle 1.
motifs_filter.m = tmp$Var2[tmp$Var1==17]
projection_vic(data_barseq, mouse.list, meet_list, regions.name, motifs=motifs_filter.m, evi.m.group_list, ncol=3, limts = c(0, max(evi.m.group))) + labs(color = "EVICG")
# Label each meet cluster by the regions whose average projection strength
# is at least 0.02 (same labelling scheme as for the particle clusters).
meet_motif_names <- lapply(sort(unique(z_meet)),
  function(j){
    # Neurons assigned to meet cluster j.
    in.j <- z_meet == j
    # Force an R x n_j matrix even when the cluster has a single neuron.
    data.j <- matrix(data_norm_cbind[, in.j], nrow = R, ncol = sum(in.j))
    data.j.average <- rowMeans(data.j)
    pp.regions <- paste(regions.name[data.j.average >= 0.02], collapse = ',')
    pp.weight <- sum(in.j)/sum(C)
    data.frame(cluster = j,
               pp.regions = pp.regions,
               pp.weight = pp.weight,
               pp.strength = data.j.average)
  })
meet_motif_names <- do.call(rbind, meet_motif_names)
# One row per meet cluster (each contributes R region rows); use R rather
# than the hard-coded 11 so this generalizes to any number of regions.
print(meet_motif_names[seq(1, nrow(meet_motif_names), R), c(1,2)])
## cluster pp.regions
## 1 1 Amyg
## 12 2 OFC,Motor,Rstr,SSctx,Cstr,Amyg,VisIp
## 23 3 Motor,SSctx,Cstr,Amyg,VisIp,AudC
## 34 4 AudC
## 45 5 VisIp,VisC,AudC
## 56 6 Cstr,Amyg,VisIp,AudC
## 67 7 VisC,AudC
## 78 8 Cstr,VisC,AudC
## 89 9 Rstr,SSctx,Cstr,Amyg,VisIp,VisC,AudC,Thal
## 100 10 Cstr,VisIp,VisC,AudC,Thal,Tect
## 111 11 Rstr,Cstr,Amyg,AudC
## 122 12 Cstr,Amyg,AudC
## 133 13 Thal,Tect
## 144 14 Thal,Tect
## 155 15 Thal,Tect
## 166 16 Thal
## 177 17 Amyg,VisIp
## 188 18 SSctx,VisIp
## 199 19 SSctx,VisIp
## 210 20 VisIp,VisC,Tect
## 221 21 SSctx,Amyg,AudC
## 232 22 VisIp
## 243 23 SSctx,VisIp
## 254 24 Rstr,Cstr,AudC
## 265 25 VisIp,Thal,Tect
## 276 26 Cstr,AudC
## 287 27 Rstr,SSctx,Cstr
## 298 28 OFC,Cstr,Amyg,VisIp
## 309 29 Cstr,Tect
## 320 30 Rstr,Cstr
## 331 31 Cstr,VisIp,AudC
## 342 32 Cstr,VisC,AudC,Thal
## 353 33 Thal,Tect
## 364 34 Thal,Tect
## 375 35 Cstr,Thal,Tect
## 386 36 Cstr,Thal,Tect
## 397 37 Cstr,Thal,Tect
## 408 38 Cstr,Thal,Tect
## 419 39 Cstr,Thal,Tect
## 430 40 Cstr,AudC
## 441 41 Cstr,VisC,AudC,Thal,Tect
## 452 42 Cstr,AudC
## 463 43 Thal,Tect
## 474 44 Thal,Tect
## 485 45 Cstr,Thal
## 496 46 Cstr,Thal
## 507 47 Cstr,VisIp,AudC
## 518 48 Cstr,VisIp,AudC
## 529 49 OFC,VisC
## 540 50 Motor,SSctx
## 551 51 SSctx,VisIp,AudC
## 562 52 OFC,Rstr,Cstr
## 573 53 VisIp,AudC,Tect
## 584 54 OFC,Motor,Rstr,SSctx,Cstr,Thal,Tect